import copy

from torchmetrics import MaxMetric
from torchmetrics import Accuracy
from torchmetrics import ConfusionMatrix
import pytorch_lightning as pl
import torch
from torch import nn
from omegaconf import OmegaConf
from hydra.utils import instantiate
import numpy as np
import wandb
from pathlib import Path
import os
import logging
import wandb

log = logging.getLogger(__name__)

class ClassificationModel(pl.LightningModule):
    """LightningModule that trains one classifier per client, one client at a time.

    ``self.model`` is the network currently being trained. When a test run
    finishes (``on_test_end``), the client index is advanced. ``self.model`` is
    then replaced by a freshly initialised network: either a deep copy of a
    shared ``initial_model`` (``cfg.same_initial``) or a new instantiation.
    Every logged key is prefixed with the current client index.

    Args:
        cfg: Hydra/OmegaConf config. Fields read in this class:
            ``cfg.model.model.num_classes``, ``cfg.same_initial``, and
            ``cfg.model`` / ``cfg.model.model`` (model configs passed to
            ``hydra.utils.instantiate``). Via the saved hparams it also reads
            ``optim.optim`` (optimizer config, see ``configure_optimizers``).
        learner_model: Currently unused; kept for interface compatibility.
    """

    def __init__(self, cfg, learner_model: nn.Module = None):
        super().__init__()
        self.num_classes = cfg.model.model.num_classes  # TODO: infer instead of config
        self.save_hyperparameters(cfg)
        self.same_initial = cfg.same_initial

        if self.same_initial:
            # Keep an untouched template; every client gets a deep copy of it,
            # so all clients start from identical weights.
            self.initial_model = instantiate(self.hparams.model.model)
            self.model: nn.Module = copy.deepcopy(self.initial_model)
            self.model.requires_grad_(requires_grad=True)
            self.model.train()
            log.info("using the same initial model for all models")
        else:
            # NOTE(review): this instantiates ``hparams.model`` while
            # ``next_client`` uses ``hparams.model.model`` (an earlier TODO asked
            # to switch back). Unless both configs build the same network,
            # client 0 gets a different model from later clients. Confirm
            # against the Hydra config before unifying.
            log.debug("self.hparams.model: %s", self.hparams.model)
            self.model = instantiate(self.hparams.model)
            log.info("initializing a new model..")
        log.debug("initial model has %d state_dict entries", len(self.model.state_dict()))

        # Index of the client currently being trained; prefixes all logged keys.
        self.current_client_idx = 0

        self.loss = nn.CrossEntropyLoss()

        # Separate metric instances per stage so each one reduces over its own epoch.
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()

        # Per-class accuracy of the most recent val / test epoch (numpy arrays once filled).
        self.per_class_test_acc = []
        self.per_class_val_acc = []

        self.train_confusion_matrix = ConfusionMatrix(num_classes=self.num_classes)
        self.val_confusion_matrix = ConfusionMatrix(num_classes=self.num_classes)
        self.test_confusion_matrix = ConfusionMatrix(num_classes=self.num_classes)

        # Best validation accuracy seen so far for the current client.
        self.val_acc_best = MaxMetric()

    def next_client(self):
        """Advance to the next client and replace ``self.model`` with a fresh network.

        With ``same_initial`` the new model is a deep copy of ``initial_model``.
        Otherwise it is instantiated anew from ``hparams.model.model``.
        """
        # path = Path(f'{os.getcwd()}/models/')
        # path.mkdir(parents=True, exist_ok=True)
        # file_path = path / f'client-{self.current_client_idx}.pth'  # best_val_acc-{self.val_acc_best.compute()}
        # torch.save(
        #     self.model.state_dict(),
        #     file_path
        # )
        # log.info(f"Saved client-{self.current_client_idx} at {path} locally")
        # artifact = wandb.Artifact(name=f"clients_models", type="model")
        # artifact.add_file(str(file_path), name=f'client-{self.current_client_idx}.pth')
        # log.info(f"Saved client-{self.current_client_idx} as an Artifact at WANDB")
        self.current_client_idx += 1
        if self.same_initial:
            self.model = copy.deepcopy(self.initial_model)
        else:
            self.model = instantiate(self.hparams.model.model)
        log.debug(
            "client %d model has %d state_dict entries",
            self.current_client_idx,
            len(self.model.state_dict()),
        )

    def forward(self, x):
        return self.model(x)

    def step(self, batch):
        """Run the model on one ``(x, y)`` batch and return ``(loss, preds, targets)``.

        Labels with more than one dimension are squeezed along dim 1 (the
        original comment cites medMNIST) so they suit ``nn.CrossEntropyLoss``.
        """
        x, y = batch
        if y.dim() > 1:
            y = y.squeeze(1)

        logits = self(x)
        loss = self.loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        return loss, preds, y

    def training_step(self, batch, batch_idx):
        loss, preds, targets = self.step(batch)

        acc = self.train_acc(preds, targets)
        self.log(f"train_client-{self.current_client_idx}/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log(f"train_client-{self.current_client_idx}/acc", acc, on_step=False, on_epoch=True, prog_bar=True)
        # "loss" must be returned for backpropagation; preds/targets are kept
        # for epoch-end hooks.
        return {"loss": loss, "preds": preds, "targets": targets}

    def validation_step(self, batch, batch_idx):
        loss, preds, targets = self.step(batch)

        self.val_confusion_matrix(preds, targets)

        acc = self.val_acc(preds, targets)
        self.log(f"val_client-{self.current_client_idx}/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log(f"val_client-{self.current_client_idx}/acc", acc, on_step=False, on_epoch=True, prog_bar=True)
        return {"loss": loss, "preds": preds, "targets": targets}

    def validation_epoch_end(self, outputs):
        """Log best-so-far accuracy, per-class accuracy and a W&B confusion plot, then reset."""
        acc = self.val_acc.compute()  # val accuracy over the current epoch
        self.val_acc_best.update(acc)
        self.log(f"val_client-{self.current_client_idx}/acc_best", self.val_acc_best.compute(), on_epoch=True,
                 prog_bar=True)

        self.per_class_val_acc = self._per_class_accuracy(self.val_confusion_matrix.compute())
        self._log_confusion_matrix_plot("val", outputs)

        summary = self.logger.experiment.summary
        summary[f"client-{self.current_client_idx}/val_acc"] = acc
        summary[f"client-{self.current_client_idx}/val_per_class_acc"] = self.per_class_val_acc

        log.info(f"per_class_val_acc = {self.per_class_val_acc}")
        log.info(f">> val_acc = {acc}")
        self.val_acc.reset()
        log.info(">> val_reset")
        self.val_confusion_matrix.reset()

    def test_step(self, batch, batch_idx):
        loss, preds, targets = self.step(batch)

        acc = self.test_acc(preds, targets)
        self.test_confusion_matrix(preds, targets)

        self.log(f"test_client-{self.current_client_idx}/loss", loss, on_step=False, on_epoch=True)
        self.log(f"test_client-{self.current_client_idx}/acc", acc, on_step=False, on_epoch=True)
        self.log(f"client-{self.current_client_idx}_best-test-acc", acc, on_step=False, on_epoch=True)
        return {"loss": loss, "preds": preds, "targets": targets}

    def test_epoch_end(self, outputs):
        """Log test accuracy, per-class accuracy and a W&B confusion plot for the current client."""
        acc = self.test_acc.compute()
        self.per_class_test_acc = self._per_class_accuracy(self.test_confusion_matrix.compute())
        self._log_confusion_matrix_plot("test", outputs)

        summary = self.logger.experiment.summary
        summary[f"client-{self.current_client_idx}/per_class_test_acc"] = self.per_class_test_acc
        summary[f"client-{self.current_client_idx}/test_acc"] = acc

        log.info(
            f"client-{self.current_client_idx}/per_class_test_acc: {self.per_class_test_acc}")
        log.info(
            f"client-{self.current_client_idx}/test_acc: {acc}")

    @staticmethod
    def _per_class_accuracy(confusion_matrix: torch.Tensor) -> np.ndarray:
        """Return per-class accuracy: the diagonal of the row-normalised confusion matrix.

        Each row is divided by its own total. ``keepdim=True`` makes the row
        sums broadcast across columns. Without it, column j would be divided
        by the total of row j. Under torchmetrics' convention (rows = true
        class), entry ``i`` is the fraction of class-``i`` samples predicted
        as ``i``. A class with no samples yields ``0/0 = nan``.
        """
        row_totals = confusion_matrix.sum(dim=1, keepdim=True)
        normalized = confusion_matrix / row_totals
        return np.diag(normalized.cpu().detach().numpy())

    def _log_confusion_matrix_plot(self, stage: str, outputs) -> None:
        """Log a W&B confusion-matrix plot built from the epoch's step outputs.

        ``outputs`` is the list of dicts returned by ``validation_step`` /
        ``test_step``; their ``targets`` and ``preds`` are concatenated.
        ``commit=False`` attaches the plot to the next committed W&B step
        instead of creating a new one.
        """
        y_true = np.concatenate([output["targets"].cpu().numpy() for output in outputs]).ravel()
        y_pred = np.concatenate([output["preds"].cpu().numpy() for output in outputs]).ravel()
        self.logger.experiment.log(
            {
                f"{stage}_client-{self.current_client_idx}/confusion_matrix": wandb.plot.confusion_matrix(
                    probs=None,
                    y_true=y_true,
                    preds=y_pred,
                    class_names=None,
                )
            },
            commit=False,
        )

    def on_train_start(self):
        # Lightning's sanity-check validation pass runs before training and
        # already went through validation_epoch_end, updating val_acc_best.
        # Discard that value so "acc_best" only reflects real training epochs.
        self.val_acc_best.reset()

    def on_epoch_end(self):
        # Reset all metrics at the end of every epoch.
        self.train_acc.reset()
        self.test_acc.reset()
        self.val_acc.reset()
        self.train_confusion_matrix.reset()
        self.val_confusion_matrix.reset()
        self.test_confusion_matrix.reset()

    def on_test_end(self):
        # Switch clients here rather than at fit end: next_client re-initialises
        # the network, which would otherwise wipe the weights before testing.
        self.next_client()
        self.val_acc_best.reset()

    def configure_optimizers(self):
        """Build the optimizer from ``hparams.optim.optim`` over ``self.model``'s parameters.

        ``self.initial_model`` (present when ``same_initial``) is also a
        registered submodule. It is only a weight template that never runs in
        ``forward``, so its parameters are deliberately excluded.
        """
        return instantiate(config=self.hparams.optim.optim, params=self.model.parameters())